{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the statistical predicate shifts\n",
    "\n",
    "In the paper, we use a statistical baseline that replaces the learnt predicate kernels with statistically calculated predicate shifts. This notebook guides you through computing and visualizing these statistical shifts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from utils.visualization_utils import get_dict\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the dataset you want to visualize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###################\n",
    "data_type = \"clevr\"\n",
    "###################\n",
    "# Per-dataset settings. The figure grid is (nrows*2) x ncols, `max_images` caps how many\n",
    "# images are processed (None means no cap), and the two paths locate the training\n",
    "# annotations and the vocabulary directory.\n",
    "if data_type==\"vrd\":\n",
    "    nrows = 7\n",
    "    ncols=10\n",
    "    fontsize = 12\n",
    "    figsize = (14,20)\n",
    "    max_images = None\n",
    "    annotations_file = \"data/VRD/annotations_train.json\"\n",
    "    vocab_dir = os.path.join('data/VRD')\n",
    "elif data_type==\"clevr\":\n",
    "    nrows = 1\n",
    "    ncols = 4\n",
    "    fontsize = 22\n",
    "    figsize = (20,10)\n",
    "    max_images = 1000\n",
    "    annotations_file = \"data/clevr/annotations_train.json\"\n",
    "    vocab_dir = os.path.join('data/clevr')\n",
    "elif data_type==\"visualgenome\":\n",
    "    nrows=7\n",
    "    ncols=10\n",
    "    fontsize = 12\n",
    "    figsize = (14,20)\n",
    "    max_images = None\n",
    "    annotations_file = \"data/VisualGenome/annotations_train.json\"\n",
    "    vocab_dir = os.path.join('data/VisualGenome')\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError in a later cell.\n",
    "    raise ValueError(\"Unknown data_type {!r}; expected 'vrd', 'clevr' or 'visualgenome'.\".format(data_type))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print out all the predicate categories in the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicate_dict, obj_subj_dict = get_dict(vocab_dir)\n",
    "# Use a context manager so the annotations file handle is closed once parsing is done.\n",
    "with open(annotations_file) as annotations_fp:\n",
    "    annotations = json.load(annotations_fp)\n",
    "print(' | '.join(predicate_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_vec(s_bbox, o_bbox):\n",
    "    \"\"\"Return the subject box center and the offset from it to the object box center.\n",
    "\n",
    "    A box center is the pair (mean of bbox[0] and bbox[1], mean of bbox[2] and bbox[3]).\n",
    "\n",
    "    Returns:\n",
    "        (s_center, vec): `s_center` is a 2-tuple; `vec` is a length-2 numpy array equal to\n",
    "        the object center minus the subject center.\n",
    "    \"\"\"\n",
    "    def _midpoints(bbox):\n",
    "        return ((bbox[0] + bbox[1]) / 2., (bbox[2] + bbox[3]) / 2.)\n",
    "\n",
    "    s_center = _midpoints(s_bbox)\n",
    "    o_center = _midpoints(o_bbox)\n",
    "    offset = np.array([o - s for s, o in zip(s_center, o_center)])\n",
    "    return s_center, offset\n",
    "def get_score(i, j, f_center, vec, eps=10e-8):\n",
    "    \"\"\"Score filter cell (i, j) by how well its offset from the center aligns with `vec`.\n",
    "\n",
    "    The score is the dot product of `vec` with the cell offset (i - f_center, j - f_center),\n",
    "    divided by the square of (eps + the larger of the two vector norms).\n",
    "\n",
    "    Args:\n",
    "        i, j: cell coordinates in the filter.\n",
    "        f_center: coordinate of the filter center, used for both axes.\n",
    "        vec: length-2 numpy array to compare the cell offset against.\n",
    "        eps: small constant added to the norm so the denominator is never zero.\n",
    "\n",
    "    Returns:\n",
    "        The scalar score.\n",
    "    \"\"\"\n",
    "    current_vec = np.array([i-f_center, j-f_center])\n",
    "    longest_norm = max(np.linalg.norm(current_vec), np.linalg.norm(vec))\n",
    "    # Honor the `eps` argument; previously a hard-coded 10e-8 was used and `eps` was ignored.\n",
    "    score = vec.dot(current_vec)/((eps + longest_norm)**2)\n",
    "    return score\n",
    "def get_filter(vec, s_center, filter_size=10):\n",
    "    \"\"\"Build a (1, filter_size, filter_size) array holding get_score for every cell.\n",
    "\n",
    "    The filter center passed to get_score is filter_size / 2. `s_center` is accepted\n",
    "    but not used by the computation.\n",
    "    \"\"\"\n",
    "    f_center = filter_size / 2\n",
    "    cells = range(filter_size)\n",
    "    scores = [[get_score(i, j, f_center, vec) for j in cells] for i in cells]\n",
    "    return np.array(scores, dtype=float).reshape(1, filter_size, filter_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's calculate all the predicate shifts. This may take a while depending on the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-predicate lists of individual filters, for the subject->object direction and its inverse.\n",
    "shift_list = {}\n",
    "inv_shift_list = {}\n",
    "progress = 0\n",
    "for progress, img in enumerate(annotations, 1):\n",
    "    for rel in annotations[img]:\n",
    "        s_bbox, o_bbox = rel[\"subject\"][\"bbox\"], rel[\"object\"][\"bbox\"]\n",
    "        s_center, vec = get_vec(s_bbox, o_bbox)\n",
    "        o_center, inv_vec = get_vec(o_bbox, s_bbox)\n",
    "        predicate = predicate_dict[rel[\"predicate\"]]\n",
    "        shift_list.setdefault(predicate, []).append(get_filter(vec, s_center))\n",
    "        inv_shift_list.setdefault(predicate, []).append(get_filter(inv_vec, o_center))\n",
    "    # `progress` counts images done so far; report the zero-based index of every 100th image.\n",
    "    if (progress - 1) % 100 == 0:\n",
    "        print('Progress: ', progress - 1)\n",
    "    if max_images is not None and progress >= max_images:\n",
    "        break\n",
    "\n",
    "def _average_filters(filter_stack):\n",
    "    # Join the (1, H, W) filters along axis 0 and average them into one (H, W) array.\n",
    "    return np.concatenate(filter_stack, axis=0).mean(axis=0)\n",
    "\n",
    "shifts = {name: _average_filters(shift_list[name]) for name in shift_list}\n",
    "inv_shifts = {name: _average_filters(inv_shift_list[name]) for name in shift_list}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finally, let's visualize what they look like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interp_method = 'spline16'\n",
    "total_rows = nrows * 2\n",
    "# squeeze=False keeps `axs` two-dimensional even when the grid has a single column.\n",
    "fig, axs = plt.subplots(nrows=total_rows, ncols=ncols, figsize=figsize, squeeze=False)\n",
    "row = 0\n",
    "col = 0\n",
    "for predicate in predicate_dict:\n",
    "    if row >= total_rows:\n",
    "        # The grid is full; the remaining predicates cannot be drawn.\n",
    "        break\n",
    "    # A predicate with no computed shift leaves its grid slot blank (explicit check\n",
    "    # instead of a bare except, so genuine plotting errors are no longer hidden).\n",
    "    if predicate in shifts:\n",
    "        panels = [\n",
    "            (predicate, shifts[predicate]),\n",
    "            (\"INV {}\".format(predicate), inv_shifts[predicate]),\n",
    "        ]\n",
    "        for offset, (title, im) in enumerate(panels):\n",
    "            if col + offset >= ncols:\n",
    "                # Odd column count: no room left on this row for the inverse panel.\n",
    "                break\n",
    "            ax = axs[row, col + offset]\n",
    "            ax.imshow(im, interpolation=interp_method)\n",
    "            ax.set_title(title, {\"fontsize\": fontsize})\n",
    "    # Each predicate occupies two columns: the shift and its inverse.\n",
    "    col += 2\n",
    "    if col >= ncols:\n",
    "        row += 1\n",
    "        col = 0\n",
    "# Hide ticks and frames on every axis, including unused slots.\n",
    "for ax in axs.flat:\n",
    "    ax.axis(\"off\")\n",
    "# Lay out after the titles exist so their height is taken into account.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
